import os
import time
import torch
import torch.nn.functional as F
import logging
import numpy as np
import torch.nn as nn
from copy import deepcopy
from sklearn.neighbors import NearestNeighbors
from timm.models import create_model

# from src.utils_data_ner import NER_collate_fn
# from src.utils_data_tc import TC_collate_fn
from src.utils_others import fisher_matrix_diag, compute_class_feature_center, compute_feature_by_dataloader, get_match_id
from src.utils_buffer import Buffer
# from src.models_tc import TC_model
from src.models_ic import IC_model
# from src.models_ner import NER_model
from src.models_classifier import SplitCosineLinear

logger = logging.getLogger()

class BaseTrainer(object):
    def __init__(self, params, CL_dataset):
        """Keep the run configuration and dataset, and set optimisation defaults.

        Args:
            params: configuration namespace; ``params.lr`` is read here and
                converted to ``float``.
            CL_dataset: the continual-learning dataset object, stored as-is.
        """
        # The network itself is created later by build_model().
        self.model = None
        self.CL_dataset = CL_dataset
        self.params = params

        # Optimisation hyper-parameters
        self.weight_decay = 5e-4
        self.mu = 0.9
        self.lr = float(params.lr)

    def begin_task(self, task_id):
        """Prepare everything needed before training on task ``task_id``.

        Steps, in order: build/extend the model, update the classifier head,
        move the model (and, with CLS-ER, both reference models) to the GPU,
        build the optimizer and, when a buffer is enabled, prepare the buffer.
        """
        self.build_model(task_id)
        self.build_classifier(task_id)

        self.model.cuda()
        if self.params.is_CLSER:
            # CLS-ER keeps two extra copies of the network
            for reference in (self.refer_model, self.refer_model_2):
                reference.cuda()

        self.build_optimizer(task_id)

        if not self.params.is_buffer:
            return
        self.build_buffer(task_id)

    def end_task(self, task_id):
        """Post-task bookkeeping: update the EWC Fisher estimate and finalise the buffer.

        With ``params.is_EWC`` the diagonal Fisher is recomputed on the current
        task's training loader and, for ``task_id > 0``, merged with the
        previous estimate for every parameter whose name contains 'encoder'.
        With ``params.is_buffer`` the buffer's ``finish_end_task`` is called.
        """
        if self.params.is_EWC:
            has_previous = task_id>0
            if has_previous:
                # Snapshot the previous Fisher entries before they are overwritten
                previous_fisher = {
                    name: self.fisher[name].clone()
                    for name, _ in self.model.named_parameters()
                    if 'encoder' in name
                }
            self.fisher = fisher_matrix_diag(self.model, self.CL_dataset.data_loader['train'][task_id])
            if has_previous:
                # Watch out! We do not want to keep t models (or fisher diagonals) in memory,
                # therefore the new estimate is averaged with the old one, the old one
                # weighted by the number of tasks it already summarises.
                for name, _ in self.model.named_parameters():
                    if 'encoder' not in name:
                        continue
                    weighted_sum = self.fisher[name]+previous_fisher[name]*task_id
                    self.fisher[name] = weighted_sum/(task_id+1)

        if self.params.is_buffer:
            replay_buffer = self.model.buffer
            replay_buffer.finish_end_task(
                task_id,
                self.CL_dataset.data_loader['train'][task_id],
                self.model,
                self.CL_dataset.CUR_CLASS,
                self.params.task_name,
            )

    def build_buffer(self, task_id):
        """Create the buffer on the first task, then prepare it for ``task_id``.

        The ``Buffer`` is attached to the model as ``self.model.buffer`` when
        ``task_id`` is 0; on every call its ``init_begin_task`` is invoked.
        """
        if task_id==0:
            cfg = self.params
            self.model.buffer = Buffer(
                cfg.buffer_size,
                self.CL_dataset.CUR_NUM_CLASS,
                cfg.batch_size,
                cfg.sampling_alg,
                cfg.is_fix_budget_each_class,
                cfg.is_mix_er,
            )
        replay_buffer = self.model.buffer
        replay_buffer.init_begin_task(task_id)

    def build_model(self, task_id):
        '''
            Build the model for the first task, or extend it for a later task.

            task_id == 0: instantiate the model for ``params.task_name`` (only
            'IC' is available in this file), reset ``self.refer_model`` and,
            depending on the flags, add the first MEMO specific encoder and
            create the two CLS-ER reference copies.

            task_id > 0: with MEMO, append a new specific encoder initialised
            from the previous one; then (see the review note below) snapshot
            the current model into ``self.refer_model`` in eval mode.

            With BaCE the per-task dictionaries on the model are reset; with
            BaCE prompt tuning on incremental tasks a soft prompt is created
            and every encoder parameter is frozen.

            Raises:
                NotImplementedError: at task 0 when ``params.task_name`` is
                    unknown, or is 'NER'/'TC' whose model classes are not
                    imported in this file.
        '''
        # Initialize a new model
        if task_id == 0:
            # Initialize the model for the first group of classes
            task_name = self.params.task_name
            if task_name == 'IC':
                model = IC_model(output_dim=self.CL_dataset.ACCUM_NUM_CLASS[task_id], params=self.params)
            elif task_name in ('NER', 'TC'):
                # NER_model / TC_model are commented out at the top of the file.
                # Previously these branches fell through and crashed with an
                # UnboundLocalError on `model`; fail explicitly instead.
                raise NotImplementedError(
                    "task_name=%r is not available: its model class is not imported" % task_name)
            else:
                raise NotImplementedError()
            self.model = model
            self.refer_model = None
            if self.params.is_MEMO:
                self.model.encoder_specific.append(self._create_specific_encoder())
            if self.params.is_CLSER:
                self.refer_model = deepcopy(self.model)
                self.refer_model_2 = deepcopy(self.model)
        else:
            if self.params.is_MEMO:
                self.model.encoder_specific.append(self._create_specific_encoder())
                # Start the new specific encoder from the previous one's weights
                self.model.encoder_specific[-1].load_state_dict(self.model.encoder_specific[-2].state_dict())
            # NOTE(review): this is true whenever CLSER and MEMO are not BOTH set, so
            # with CLSER alone refer_model is replaced by a fresh snapshot on every
            # task -- confirm whether `not (is_CLSER or is_MEMO)` was intended.
            if not (self.params.is_CLSER and self.params.is_MEMO):
                self.refer_model = deepcopy(self.model)
                self.refer_model.eval()

        if self.params.is_BaCE:
            self.model.score_dict = {}
            self.model.knn_id_dict = {}
            self.model.knn_dist_dict = {}

            if self.params.BaCE_prompt_tuning and task_id>0:
                # Only Prompt Tuning for incremental tasks
                total_num_prompt = (self.CL_dataset.NUM_TASK-1)*self.params.BaCE_prompt_len
                self.model.soft_prompt = nn.Parameter(torch.randn((total_num_prompt,self.model.encoder.embed_dim)))
                # The normal init above is overwritten by a uniform init in [-1, 1]
                nn.init.uniform_(self.model.soft_prompt, -1, 1)
                for n, p in self.model.encoder.named_parameters():
                    p.requires_grad = False

    def _create_specific_encoder(self):
        '''Create a pretrained ``<params.backbone>_specific`` model for a MEMO specific encoder.'''
        return create_model(model_name=self.params.backbone+'_specific', # 'vit_base_patch16_224', 'vit_tiny_patch16_224_in21k'
                            pretrained=True,
                            num_classes=1000,
                            drop_rate=0.0,
                            drop_path_rate=0.0,
                            drop_block_rate=None)

    def build_classifier(self, task_id):
        '''
            Update the architecture of the classifier

            task_id == 0: the classifier is left untouched; its dimensions are
            logged and its hidden_dim is remembered in ``self.one_hidden_dim``.
            task_id == 1: the classifier is replaced by a ``SplitCosineLinear``
            whose ``fc0`` receives the old weights (new classes live in ``fc1``).
            task_id >= 2: the old ``fc0`` and ``fc1`` weights are stacked into
            the new ``fc0``.
            With MEMO the new head's input is wider by ``self.one_hidden_dim``
            and the old weights are written into the leading columns only.
            With CLS-ER both reference models get their heads rebuilt too.
            Finally, with ``params.is_imprint`` and task_id > 0, ``fc1`` is
            initialised from the class feature centers of the new classes.

            Raises:
                NotImplementedError: when ``params.task_name`` is not one of
                    'NER', 'TC', 'IC'.
        '''
        if task_id == 0:
            hidden_dim = self.model.classifier.hidden_dim
            output_dim = self.model.classifier.output_dim
            logger.info("hidden_dim=%d, output_dim=%d"%(hidden_dim,output_dim))
            # Width of a single encoder's feature; used as the MEMO growth step below
            self.one_hidden_dim = hidden_dim
        elif task_id == 1:
            # The task-0 classifier exposes .weight directly (no fc0/fc1 split yet)
            hidden_dim = self.model.classifier.hidden_dim
            output_dim = self.model.classifier.output_dim
            logger.info("hidden_dim=%d, old_output_dim=%d, new_output_dim=%d"%(
                                        hidden_dim,
                                        output_dim,
                                        self.CL_dataset.CUR_NUM_CLASS[task_id]))
            if self.params.is_MEMO:
                new_fc = SplitCosineLinear(hidden_dim+self.one_hidden_dim, output_dim, self.CL_dataset.CUR_NUM_CLASS[task_id])
                # Slice assignment copies the old weights into all but the last
                # one_hidden_dim columns; those trailing columns keep new_fc's own init.
                new_fc.fc0.weight.data[:,:-self.one_hidden_dim] = self.model.classifier.weight.data
                new_fc.sigma.data = self.model.classifier.sigma.data
                self.model.classifier = new_fc
            else:
                new_fc = SplitCosineLinear(hidden_dim, output_dim, self.CL_dataset.CUR_NUM_CLASS[task_id])
                # Rebinding .data (no slice): fc0 now refers to the old weight tensor itself
                new_fc.fc0.weight.data = self.model.classifier.weight.data
                new_fc.sigma.data = self.model.classifier.sigma.data
                self.model.classifier = new_fc

            if self.params.is_CLSER:
                # Same head replacement for the two CLS-ER reference models.
                # NOTE(review): these heads use hidden_dim without the MEMO widening -- confirm
                # that CLSER and MEMO are never enabled together.
                new_fc_1 = SplitCosineLinear(hidden_dim, output_dim, self.CL_dataset.CUR_NUM_CLASS[task_id])
                new_fc_1.fc0.weight.data = self.refer_model.classifier.weight.data
                new_fc_1.sigma.data = self.refer_model.classifier.sigma.data
                self.refer_model.classifier = new_fc_1
                new_fc_2 = SplitCosineLinear(hidden_dim, output_dim, self.CL_dataset.CUR_NUM_CLASS[task_id])
                new_fc_2.fc0.weight.data = self.refer_model_2.classifier.weight.data
                new_fc_2.sigma.data = self.refer_model_2.classifier.sigma.data
                self.refer_model_2.classifier = new_fc_2
        else:
            # From the third task on the current head is already a SplitCosineLinear (fc0 + fc1)
            hidden_dim = self.model.classifier.hidden_dim
            output_dim0 = self.model.classifier.fc0.output_dim
            output_dim1 = self.model.classifier.fc1.output_dim
            logger.info("hidden_dim=%d, old_output_dim=%d, new_output_dim=%d"%(
                                                            hidden_dim,
                                                            output_dim0+output_dim1,
                                                            self.CL_dataset.CUR_NUM_CLASS[task_id]))                                                
            
            if self.params.is_MEMO:
                new_fc = SplitCosineLinear(hidden_dim+self.one_hidden_dim, output_dim0+output_dim1, self.CL_dataset.CUR_NUM_CLASS[task_id])
                # Rows [0, output_dim0) come from the old fc0, the remaining rows from the old fc1;
                # only the leading columns are filled, as in the task-1 MEMO case.
                new_fc.fc0.weight.data[:output_dim0,:-self.one_hidden_dim] = self.model.classifier.fc0.weight.data
                new_fc.fc0.weight.data[output_dim0:,:-self.one_hidden_dim] = self.model.classifier.fc1.weight.data
                new_fc.sigma.data = self.model.classifier.sigma.data
                self.model.classifier = new_fc
            else:
                new_fc = SplitCosineLinear(hidden_dim, output_dim0+output_dim1, self.CL_dataset.CUR_NUM_CLASS[task_id])
                # Stack old fc0 and old fc1 row-wise into the new fc0
                new_fc.fc0.weight.data[:output_dim0] = self.model.classifier.fc0.weight.data
                new_fc.fc0.weight.data[output_dim0:] = self.model.classifier.fc1.weight.data
                new_fc.sigma.data = self.model.classifier.sigma.data
                self.model.classifier = new_fc

            if self.params.is_CLSER:
                # NOTE(review): sizes are taken from self.model's old head; assumes the
                # reference models' heads have the same fc0/fc1 sizes -- confirm.
                new_fc_1 = SplitCosineLinear(hidden_dim, output_dim0+output_dim1, self.CL_dataset.CUR_NUM_CLASS[task_id])
                new_fc_1.fc0.weight.data[:output_dim0] = self.refer_model.classifier.fc0.weight.data
                new_fc_1.fc0.weight.data[output_dim0:] = self.refer_model.classifier.fc1.weight.data
                new_fc_1.sigma.data = self.refer_model.classifier.sigma.data
                self.refer_model.classifier = new_fc_1
                new_fc_2 = SplitCosineLinear(hidden_dim, output_dim0+output_dim1, self.CL_dataset.CUR_NUM_CLASS[task_id])
                new_fc_2.fc0.weight.data[:output_dim0] = self.refer_model_2.classifier.fc0.weight.data
                new_fc_2.fc0.weight.data[output_dim0:] = self.refer_model_2.classifier.fc1.weight.data
                new_fc_2.sigma.data = self.refer_model_2.classifier.sigma.data
                self.refer_model_2.classifier = new_fc_2

        # Log the classes seen so far and the classes introduced by this task
        if self.params.task_name == 'NER':
            logger.info("All seen entity types = %s"%str(self.CL_dataset.ACCUM_ENTITY[task_id]))
            logger.info("New entity types = %s"%str(self.CL_dataset.CUR_ENTITY[task_id]))
        elif self.params.task_name == 'TC':
            logger.info("All seen class = %s"%(self.CL_dataset.CLASSNAME_LIST[:self.CL_dataset.ACCUM_NUM_CLASS[task_id]]))
            logger.info("New classes = %s"%(self.CL_dataset.CLASSNAME_LIST[self.CL_dataset.PRE_ACCUM_NUM_CLASS[task_id]:self.CL_dataset.ACCUM_NUM_CLASS[task_id]]))
        elif self.params.task_name == 'IC':
            logger.info("All seen class = %s"%(self.CL_dataset.LABEL_LIST[:self.CL_dataset.ACCUM_NUM_CLASS[task_id]]))
            logger.info("New classes = %s"%(self.CL_dataset.LABEL_LIST[self.CL_dataset.PRE_ACCUM_NUM_CLASS[task_id]:self.CL_dataset.ACCUM_NUM_CLASS[task_id]]))
        else:
            raise NotImplementedError()
        
        # Imprint: initialize the new classifier
        if task_id>0 and self.params.is_imprint:  
            # (1) compute the average norm of the old embeddings (rows of fc0)
            old_embedding_norm = self.model.classifier.fc0.weight.data.norm(dim=1, keepdim=True)
            average_old_embedding_norm = torch.mean(old_embedding_norm, dim=0).cpu().type(torch.DoubleTensor)
            # (2) compute class centers for each new classes (B-/I-)
            class_center_matrix = compute_class_feature_center(self.CL_dataset.data_loader['train'][task_id], 
                                        model=self.model, 
                                        select_class_indexes=self.CL_dataset.CUR_CLASS[task_id], 
                                        is_normalize=True)
            # (3) rescale the norm for each classes (each row) 
            rescale_weight_matrix = F.normalize(class_center_matrix, p=2, dim=-1) * average_old_embedding_norm
            # Rows whose first entry is NaN are treated as classes without a usable center
            nan_pos_list = torch.where(torch.isnan(rescale_weight_matrix[:,0]))[0]
            for nan_pos in nan_pos_list:
                # Only odd rows may be missing. NOTE(review): the B-/I- pairing is NER
                # terminology; confirm this fallback is meant to apply to IC/TC as well.
                assert nan_pos%2==1, "Class not appear in dataloader!!!"
                # replace the weight of I- with B-
                rescale_weight_matrix[nan_pos] = rescale_weight_matrix[nan_pos-1].clone()
            # The imprinted weights become the new-class part (fc1) of the head
            self.model.classifier.fc1.weight.data = rescale_weight_matrix.type(torch.FloatTensor)

    def build_optimizer(self, task_id):
        '''
            Create ``self.optimizer`` (SGD) and ``self.scheduler`` for ``task_id``.

            Parameter groups, in order:
              1. classifier head (lr = params.final_fc_lr, no weight decay); on
                 incremental tasks with ``is_fix_old_cls`` only ``classifier.fc1``.
              2. encoder(s) (lr = params.lr, weight decay = params.weight_decay):
                 the MEMO general + latest specific encoder, or the adapter /
                 full encoder parameters unless the encoder is fixed. Skipped on
                 incremental tasks when BaCE prompt tuning is enabled.
              3. the BaCE soft prompt (incremental tasks with prompt tuning only).

            Raises:
                NotImplementedError: for an unknown ``params.scheduler``.
        '''
        cfg = self.params
        param_groups = []

        first_task = task_id==0
        if first_task or task_id>0:
            # Classifier
            if (not first_task) and cfg.is_fix_old_cls:
                head_parameters = self.model.classifier.fc1.parameters()
            else:
                head_parameters = self.model.classifier.parameters()
            param_groups.append({'params': head_parameters, 'lr': float(cfg.final_fc_lr),
                                 'weight_decay': 0.})

            # Prompt tuning only ever applies to incremental tasks
            prompt_tuning = (not first_task) and cfg.is_BaCE and cfg.BaCE_prompt_tuning

            # Encoder backbone
            if cfg.is_MEMO:
                for encoder in (self.model.encoder_general, self.model.encoder_specific[-1]):
                    param_groups.append({'params': encoder.parameters(), 'lr': float(cfg.lr),
                                         'weight_decay': float(cfg.weight_decay)})
            elif (not prompt_tuning) and (not cfg.is_fix_enc):
                # (Prompt tuning fixes the backbone model, hence no encoder group in that case)
                if cfg.use_adapter:
                    encoder_parameters = self.model.encoder.adapter_parameters()
                else:
                    encoder_parameters = self.model.encoder.parameters()
                param_groups.append({'params': encoder_parameters, 'lr': float(cfg.lr),
                                     'weight_decay': float(cfg.weight_decay)})

            # Prompt
            if prompt_tuning:
                param_groups.append({'params': self.model.soft_prompt, 'lr': float(cfg.final_fc_lr),
                                     'weight_decay': 0.})

        self.optimizer = torch.optim.SGD(param_groups, momentum=cfg.mu)

        scheduler_name = cfg.scheduler
        if scheduler_name == 'multistep':
            # NOTE(review): eval() executes arbitrary code from the config string;
            # consider ast.literal_eval if milestones are always a literal list.
            milestones = eval(cfg.scheduler_milestone)
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,milestones=milestones,gamma=0.1)
        elif scheduler_name == 'constant':
            self.scheduler = None
        else:
            raise NotImplementedError()

    def observe_batch(self, idx, X, y, task_id, epoch_id, global_step, is_replay=False):
        """Run one optimisation step on a batch and return its loss values.

        Args:
            idx: sample indices of the batch (only passed to the BaCE loss).
            X: model inputs.
            y: labels; cast to ``long`` before use.
            task_id: index of the current task.
            epoch_id: current epoch (not used in this method).
            global_step: step counter, used by the CLS-ER reference-model updates.
            is_replay: when True the plain ``batch_loss`` is used even for
                incremental tasks.

        Returns:
            ``(total_loss.item(), ce_loss, distill_loss)``; the last two are
            returned exactly as produced by the model's loss method.
        """
        y = y.long()

        # Switch to training mode BEFORE the forward pass. Previously train() was
        # only called right before backward(), so the first batch after any
        # eval() call was forwarded in eval mode.
        self.model.train()

        # forward
        if self.params.is_BaCE and self.params.BaCE_prompt_tuning and task_id>0:
            logits, features = self.model.forward(X, return_feat=True, task_id=task_id)
        else:
            logits, features = self.model.forward(X, return_feat=True)

        # Compute loss
        if task_id==0 or is_replay:
            total_loss, ce_loss, distill_loss = self.model.batch_loss(logits, y)
        else:
            # The first enabled method wins; fall back to the plain loss otherwise
            if self.params.is_CLSER:
                total_loss, ce_loss, distill_loss = self.model.batch_loss_clser(X, logits, y, self.refer_model, self.refer_model_2)
                # Stochastic EMA updates of the two reference models
                if torch.rand(1) < self.params.CLSER_freq_1:
                    self.update_stable_model_variables(global_step)
                if torch.rand(1) < self.params.CLSER_freq_2:
                    self.update_plastic_model_variables(global_step)
            elif self.params.is_BaCE:
                total_loss, ce_loss, distill_loss = self.model.batch_loss_distill_BaCE(X, logits, y, idx, self.refer_model, task_id, self.CL_dataset.NUM_TASK)
            elif self.params.is_distill or self.params.is_SmoothE:
                total_loss, ce_loss, distill_loss = self.model.batch_loss_distill(X, logits, y, self.refer_model)
            elif self.params.is_lucir:
                total_loss, ce_loss, distill_loss = self.model.batch_loss_lucir(X, features, logits, y, self.refer_model)
            elif self.params.is_podnet:
                total_loss, ce_loss, distill_loss = self.model.batch_loss_podnet(X, features, logits, y, self.refer_model)
            elif self.params.is_LWF:
                total_loss, ce_loss, distill_loss = self.model.batch_loss_LwF(X, logits, y, self.refer_model, task_id)
            elif self.params.is_EWC:
                total_loss, ce_loss, distill_loss = self.model.batch_loss_EWC(logits, y, self.refer_model, self.fisher, task_id)
            else:
                total_loss, ce_loss, distill_loss = self.model.batch_loss(logits, y)

        # backward
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return total_loss.item(), ce_loss, distill_loss

    def update_stable_model_variables(self,global_step):
        """Exponential-moving-average update of ``self.refer_model`` towards ``self.model``.

        The decay is ``min(1 - 1/(global_step + 1), params.CLSER_alpha_1)``, so
        early steps follow the current model more closely.
        """
        alpha = min(1 - 1 / (global_step + 1),  self.params.CLSER_alpha_1)
        for ema_param, param in zip(self.refer_model.parameters(), self.model.parameters()):
            # ema <- alpha * ema + (1 - alpha) * param
            # (keyword form; the positional add_(scalar, tensor) overload is deprecated)
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    def update_plastic_model_variables(self,global_step):
        """Exponential-moving-average update of ``self.refer_model_2`` towards ``self.model``.

        The decay is ``min(1 - 1/(global_step + 1), params.CLSER_alpha_2)``, so
        early steps follow the current model more closely.
        """
        alpha = min(1 - 1 / (global_step + 1), self.params.CLSER_alpha_2)
        for ema_param, param in zip(self.refer_model_2.parameters(), self.model.parameters()):
            # ema <- alpha * ema + (1 - alpha) * param
            # (keyword form; the positional add_(scalar, tensor) overload is deprecated)
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

    def begin_epoch(self, task_id, epoch_id):
        """Refresh the BaCE reference model and the per-sample KNN tables.

        Only active for incremental 'IC' tasks with ``params.is_BaCE``;
        otherwise nothing happens.

        * ``epoch_id <= 1``: the reference model is re-created as an eval-mode
          copy of the current model.
        * ``epoch_id > 1`` and ``BaCE_beta < 1``: the reference parameters are
          moved towards the current ones, ``ref = beta*ref + (1-beta)*cur``.

        Afterwards the features of the current task's training set are computed
        with the reference model and, for every sample index, the indices and
        distances of its ``BaCE_k`` matches are stored in
        ``self.model.knn_id_dict`` / ``self.model.knn_dist_dict``.
        """
        if not (task_id>0 and self.params.task_name=='IC' and self.params.is_BaCE):
            return

        if epoch_id>1:
            beta = self.params.BaCE_beta
            if beta<1.0:
                for (n_refer, p_refer), (n, p) in zip(self.refer_model.named_parameters(),self.model.named_parameters()):
                    # Parameters are paired positionally; skip pairs whose names disagree
                    if n_refer!=n:
                        continue
                    device = p_refer.device
                    _p = p.detach().clone().to(device)
                    p_refer.detach().copy_(p_refer.detach().mul(beta)+_p.detach().mul(1-beta))
        else:
            self.refer_model = deepcopy(self.model)
            self.refer_model.eval()

        idx_list, features_matrix, _ = compute_feature_by_dataloader(self.CL_dataset.data_loader['train'][task_id],
                                                                     model=self.refer_model,
                                                                     is_normalize=True,
                                                                     return_idx=True)
        features_matrix = torch.tensor(features_matrix)

        # Position in the feature matrix -> dataset sample index
        tmp2idx = [_idx.item() for _idx, _ in zip(idx_list, features_matrix)]

        top_k = self.params.BaCE_k
        knn_dist_matrix, knn_id_matrix = get_match_id(features_matrix,top_k=top_k)
        for tmp_i, idx in enumerate(tmp2idx):
            # Neighbour ids come back as matrix positions; translate them to sample indices
            self.model.knn_id_dict[idx] = np.array([tmp2idx[tmp_j] for tmp_j in knn_id_matrix[tmp_i].numpy()])
            self.model.knn_dist_dict[idx] = knn_dist_matrix[tmp_i].numpy()

            # if self.params.BaCE_is_plot:

            #     from PIL import Image

            #     for plot_i in [1,2,3,4,5,6,7,8,9,10]:
            #         origin_img = Image.fromarray(x_list[plot_i])
            #         origin_img_y = y_list[plot_i]
                    
            #         origin_img.save('imgs/%d_origin_%d.png'%(plot_i, origin_img_y))
            #         for tmp_i, knn_idx in enumerate(knn_id_matrix[plot_i]):
            #             knn_img = Image.fromarray(x_list[knn_idx])
            #             knn_img_y = y_list[knn_idx]
            #             knn_img_dist = knn_dist_matrix[plot_i][tmp_i]
            #             knn_img.save('imgs/%d_knn_%d_%.2f.png'%(plot_i,knn_img_y,knn_img_dist))

    def end_epoch(self, task_id, epoch_id):
        """Step the learning-rate scheduler, if there is one, and log the new rate."""
        scheduler = self.scheduler
        if scheduler is None:
            return
        scheduler.step()
        logger.info('Epoch %d, Learning rate = %s'%(epoch_id,scheduler.get_last_lr()))

    def meta_mbpa_predict(self, test_data_loader, task_id):
        """
        Using Meta-MBPA test adaptation

        A copy of the model is fine-tuned for ``params.mbpa_step`` steps on
        batches drawn from the buffer, with an L2 penalty
        (``params.mbpa_lambda``) that keeps it close to the original
        parameters; the adapted copy then predicts the whole test loader.

        Args:
            test_data_loader: data_loader for test sample
            task_id: the task id (must be > 0)

        Returns:
            (gold_line, pred_line): 1-D tensors of gold and predicted labels.

        Raises:
            NotImplementedError: during adaptation, because no collate function
                is available in this file for any task.
        """
        assert task_id>0
        # local adaptation
        gold_line, pred_line = [], []
        with torch.no_grad():
            # Flattened snapshot of the original parameters (anchor of the distance penalty)
            org_params = torch.cat([torch.reshape(param, [-1]) for param in self.model.parameters()], 0)

        adapt_model = deepcopy(self.model)
        adapt_model.cuda()
        adapt_optimizer = torch.optim.Adam(adapt_model.parameters(), lr=float(self.params.mbpa_lr))

        for _ in range(self.params.mbpa_step):
            adapt_model.train()

            knn_data = self.model.buffer.get_buffer_batch() # [(idx, X, y),...]
            knn_X, knn_y = self._collate_meta_mbpa_batch(knn_data)
            knn_X, knn_y = knn_X.cuda().detach(), knn_y.cuda().detach()

            logits = adapt_model.forward(knn_X)
            classification_loss, _, _ = adapt_model.batch_loss(logits, knn_y)

            new_params = torch.cat([torch.reshape(param, [-1]) for param in adapt_model.parameters()], 0)
            distance_loss = (org_params-new_params).pow(2).sum()

            adapt_loss = self.params.mbpa_lambda*distance_loss+classification_loss

            adapt_optimizer.zero_grad()
            adapt_loss.backward()
            adapt_optimizer.step()

        with torch.no_grad():
            for idx, x, y in test_data_loader:
                adapt_model.eval()
                logits = adapt_model.forward(x.cuda())
                pred_line.append(logits.view(-1,logits.shape[-1]).argmax(-1).detach().cpu())
                gold_line.append(y.flatten().detach().cpu())

        gold_line = torch.cat(gold_line)
        pred_line = torch.cat(pred_line)

        return gold_line, pred_line

    def _collate_meta_mbpa_batch(self, knn_data):
        """Collate buffer samples ``[(idx, X, y), ...]`` into a ``(X, y)`` pair.

        Raises:
            NotImplementedError: always, for now. The NER/TC collate functions
                are commented out in this file (previously this surfaced as an
                UnboundLocalError on ``knn_X``) and no other task is supported.
        """
        task_name = self.params.task_name
        if task_name == 'NER':
            # _, knn_X, knn_y = NER_collate_fn(knn_data)
            raise NotImplementedError("Meta-MBPA needs NER_collate_fn, which is not imported in this file")
        if task_name == 'TC':
            # _, knn_X, knn_y = TC_collate_fn(knn_data)
            raise NotImplementedError("Meta-MBPA needs TC_collate_fn, which is not imported in this file")
        raise NotImplementedError()
    
    def mbpa_predict(self, test_data_loader, task_id):
        """
        Using MBPA test adaptation (very slow inference time, takes 1 week for 5 datasets)

        For every test sample the ``params.mbpa_K`` nearest buffer samples
        (nearest in the [CLS] feature of a BERT key network) are retrieved, a
        fresh copy of the model is fine-tuned on them for ``params.mbpa_step``
        steps with an L2 penalty towards the original parameters, and the
        adapted copy predicts that single sample.

        Args:
            test_data_loader: data_loader for test sample
            task_id: the task id (not used in this method)

        Returns:
            (gold_line, pred_line): two Python lists with one label per test sample.

        Raises:
            NotImplementedError: because no collate function is available in
                this file for any task; raised before the key network is loaded.
        """
        # Collate the buffer first so that unsupported tasks fail before BERT is loaded
        buffer_all = self.model.buffer.get_buffer_all() # [(idx, X, y),...]
        buffer_X, _ = self._collate_mbpa_samples(buffer_all)

        from transformers import BertModel, BertConfig
        if not hasattr(self, 'mbpa_key_network'):
            bert_config = BertConfig.from_pretrained(self.params.backbone)
            bert_config.output_hidden_states = True
            bert_config.return_dict = True
            # NOTE(review): the config comes from params.backbone but the weights are
            # always 'bert-base-cased' -- confirm this is intended.
            mbpa_key_network = BertModel.from_pretrained('bert-base-cased',config=bert_config)
            setattr(self, 'mbpa_key_network',mbpa_key_network)
        self.mbpa_key_network.eval()
        self.mbpa_key_network.cuda()

        # obtain keys for memory samples
        buffer_keys = []
        with torch.no_grad():
            # At least one chunk: torch.chunk rejects chunks=0 (buffers with < 32 samples)
            num_chunks = max(1, buffer_X.shape[0]//32)
            for buffer_batch_X in torch.chunk(buffer_X,chunks=num_chunks,dim=0):
                buffer_batch_X = buffer_batch_X.cuda()
                outputs = self.mbpa_key_network(buffer_batch_X)
                buffer_keys.append(outputs.last_hidden_state[:,0,:].cpu()) # [CLS] feature
            buffer_keys = torch.cat(buffer_keys,dim=0)

        # obtain keys for test batch
        test_keys = []
        with torch.no_grad():
            for idx, x, y in test_data_loader:
                x, y = x.cuda(), y.cuda()
                outputs = self.mbpa_key_network(x)
                test_keys.append(outputs.last_hidden_state[:,0,:].cpu()) # [CLS] feature
        test_keys = torch.cat(test_keys,dim=0).numpy()

        # find KNNs
        nbrs = NearestNeighbors(n_jobs=4).fit(buffer_keys)
        knn_idx_all = nbrs.kneighbors(test_keys,n_neighbors=self.params.mbpa_K,return_distance=False)
        total_test = knn_idx_all.shape[0]
        knn_data_all = []
        for i in range(total_test):
            knn_data = self.model.buffer.get_buffer_batch(select_idx=knn_idx_all[i]) # [(idx, X, y),...]
            knn_data_all.append(self._collate_mbpa_samples(knn_data))

        # local adaptation
        gold_line, pred_line = [], []
        with torch.no_grad():
            # Flattened snapshot of the original parameters (anchor of the distance penalty)
            org_params = torch.cat([torch.reshape(param, [-1]) for param in self.model.parameters()], 0)

        test_cnt = -1
        begin_time = time.time()
        for idx, x, y in test_data_loader:
            for i in range(x.shape[0]):
                test_cnt+=1
                # Every test sample gets its own freshly copied model and optimizer
                adapt_model = deepcopy(self.model)

                adapt_model.cuda()
                adapt_optimizer = torch.optim.Adam(adapt_model.parameters(), lr=float(self.params.mbpa_lr))
                knn_X, knn_y = knn_data_all[test_cnt]
                knn_X, knn_y = knn_X.cuda().detach(), knn_y.cuda().detach()

                for _ in range(self.params.mbpa_step):
                    adapt_model.train()
                    new_params = torch.cat([torch.reshape(param, [-1]) for param in adapt_model.parameters()], 0)
                    logits = adapt_model.forward(knn_X)
                    classification_loss, _, _ = adapt_model.batch_loss(logits, knn_y)
                    distance_loss = (org_params-new_params).pow(2).sum()
                    adapt_loss = self.params.mbpa_lambda*distance_loss+classification_loss
                    adapt_optimizer.zero_grad()
                    adapt_loss.backward()
                    adapt_optimizer.step()

                # predict
                with torch.no_grad():
                    adapt_model.eval()
                    logits = adapt_model.forward(x[i].unsqueeze(0).cuda()).squeeze(0)
                    pred_line.append(logits.view(-1,logits.shape[-1]).argmax(-1).detach().cpu().item())
                    gold_line.append(y[i].flatten().detach().cpu().item())
                    torch.cuda.empty_cache()

                if (test_cnt+1)%100==0:
                    # Remaining time = average time per processed sample * samples left
                    logger.info('MBPA test adapt [%d/%d], remaining time %.2f minutes'%(
                        test_cnt+1,
                        total_test,
                        (time.time()-begin_time)/(test_cnt+1)*(total_test-test_cnt-1)/60))

        return gold_line, pred_line

    def _collate_mbpa_samples(self, samples):
        """Collate buffer samples ``[(idx, X, y), ...]`` into a ``(X, y)`` pair.

        Raises:
            NotImplementedError: always, for now. The NER/TC collate functions
                are commented out in this file (previously this surfaced as an
                UnboundLocalError on ``buffer_X`` / ``knn_X``) and no other
                task is supported.
        """
        task_name = self.params.task_name
        if task_name == 'NER':
            # _, X, y = NER_collate_fn(samples)
            raise NotImplementedError("MBPA needs NER_collate_fn, which is not imported in this file")
        if task_name == 'TC':
            # _, X, y = TC_collate_fn(samples)
            raise NotImplementedError("MBPA needs TC_collate_fn, which is not imported in this file")
        raise NotImplementedError()


    def save_model(self, save_name, path=''):
        """
        Write a checkpoint of the current model.

        The file is ``<path>/<save_name>``, or ``<params.dump_path>/<save_name>``
        when ``path`` is empty. It stores the model's hidden/output dimensions,
        the encoder state dict(s) (general + specific with MEMO) and the
        classifier module object itself.
        """
        target_dir = path if len(path)>0 else self.params.dump_path
        saved_path = os.path.join(target_dir, str(save_name))

        use_memo = self.params.is_MEMO
        net = self.model
        checkpoint = {
            "hidden_dim": net.hidden_dim,
            "output_dim": net.output_dim,
        }
        if use_memo:
            checkpoint["encoder_general"] = net.encoder_general.state_dict()
            checkpoint["encoder_specific"] = net.encoder_specific.state_dict()
        else:
            checkpoint["encoder"] = net.encoder.state_dict()
        # The classifier is saved as a whole module, not as a state dict
        checkpoint["classifier"] = net.classifier

        torch.save(checkpoint, saved_path)
        logger.info("Best model has been saved to %s" % saved_path)

    def load_model(self, load_name, path=''):
        """
        Restore the model from a checkpoint.

        The file is ``<path>/<load_name>``, or ``<params.dump_path>/<load_name>``
        when ``path`` is empty. The model's hidden/output dimensions and
        encoder weights are restored and its classifier is replaced by the
        stored classifier object.
        """
        source_dir = path if len(path)>0 else self.params.dump_path
        load_path = os.path.join(source_dir, str(load_name))

        # NOTE(review): the checkpoint holds a whole classifier module; recent PyTorch
        # releases may refuse to unpickle it unless weights_only=False is passed -- confirm
        # against the PyTorch version in use.
        ckpt = torch.load(load_path)

        use_memo = self.params.is_MEMO
        net = self.model
        net.hidden_dim = ckpt['hidden_dim']
        net.output_dim = ckpt['output_dim']
        if use_memo:
            net.encoder_general.load_state_dict(ckpt['encoder_general'])
            net.encoder_specific.load_state_dict(ckpt['encoder_specific'])
        else:
            net.encoder.load_state_dict(ckpt['encoder'])
        net.classifier = ckpt['classifier']

        logger.info("Model has been load from %s" % load_path)